// SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <cstdint>
#include <tt-metalium/core_coord.hpp>
#include <tt-metalium/buffer.hpp>
#include "ttnn/tensor/tensor.hpp"
#include "ttnn/operations/ccl/shared_with_host/hetergeneous_data_structs.hpp"
#include <tt-metalium/constants.hpp>
#include "ttnn/operations/ccl/ccl_host_datastructures.hpp"
#include "ttnn/operations/ccl/ccl_common.hpp"
#include "ttnn/operations/ccl/ccl_op_fusion.hpp"
#include <tt-metalium/global_semaphore.hpp>
#include "ttnn/global_semaphore.hpp"

#include "strided_all_gather_async_device_operation_types.hpp"
#include "strided_all_gather_async_program.hpp"

#include "ttnn/decorators.hpp"

#include <optional>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>

namespace ttnn::operations::experimental::ccl {
namespace strided_all_gather_async {

// Device-operation descriptor for the strided all-gather async op.
//
// The struct carries no state: it only exposes type aliases and static entry
// points. All functions are declared here and defined in another translation
// unit, so the notes below describe signatures, not implementation details.
struct StridedAllGatherAsync {
    // Re-export the op's attribute/argument/result types from the enclosing
    // strided_all_gather_async namespace (presumably declared in
    // strided_all_gather_async_device_operation_types.hpp - TODO confirm).
    using operation_attributes_t = strided_all_gather_async::operation_attributes_t;
    using tensor_args_t = strided_all_gather_async::tensor_args_t;
    using spec_return_value_t = strided_all_gather_async::spec_return_value_t;
    using tensor_return_value_t = strided_all_gather_async::tensor_return_value_t;

    // Closed set of program factories for this op; currently a single alternative.
    using program_factory_t = std::variant<program::StridedAllGatherAsyncProgramFactory>;

    // Returns the program factory alternative to use for the given attributes
    // and tensor arguments.
    static program_factory_t select_program_factory(
        const operation_attributes_t&, const tensor_args_t&);

    // Validation hooks for the program-cache miss and hit paths.
    // NOTE(review): presumably invoked by the ttnn device-operation framework
    // before a program is created / reused - confirm against the framework.
    static void validate_on_program_cache_miss(
        const operation_attributes_t&, const tensor_args_t&);
    static void validate_on_program_cache_hit(
        const operation_attributes_t&, const tensor_args_t&);

    // Output description and output creation, both derived from the same
    // (attributes, tensor args) pair.
    static spec_return_value_t compute_output_specs(
        const operation_attributes_t&, const tensor_args_t&);
    static tensor_return_value_t create_output_tensors(
        const operation_attributes_t&, const tensor_args_t&);

    // Returns the hash value for this (attributes, tensor args) pair.
    // NOTE(review): presumably used as the program-cache key - confirm.
    static tt::tt_metal::operation::Hash compute_program_hash(
        const operation_attributes_t&, const tensor_args_t&);

    // Maps the caller-facing argument list onto the
    // (operation_attributes_t, tensor_args_t) pair consumed by the functions
    // above.
    //
    // Required arguments: input_tensor, dim, multi_device_global_semaphore,
    // num_links, topology. Every std::optional argument may be left empty;
    // how an empty value is resolved is decided by the definition, which is
    // not visible in this header.
    // NOTE(review): the mm_* parameters look matmul-related (fused op
    // tuning?) - confirm against the implementation.
    static std::tuple<operation_attributes_t, tensor_args_t> invoke(
        const Tensor& input_tensor,
        const std::optional<ttnn::Tensor>& persistent_output_buffer,
        uint32_t dim,
        const std::vector<GlobalSemaphore>& multi_device_global_semaphore,
        uint32_t num_links,
        const std::optional<MemoryConfig>& memory_config,
        ttnn::ccl::Topology topology,
        const std::optional<uint32_t>& cluster_axis,
        const std::optional<uint32_t>& tiles_per_chunk,
        const std::optional<uint32_t>& num_workers_per_link,
        const std::optional<uint32_t>& num_buffers_per_channel,
        const std::optional<uint32_t>& mm_cores_y,
        const std::optional<uint32_t>& mm_block_ht,
        const std::optional<uint32_t>& mm_block_wt);
};
}  // namespace strided_all_gather_async
}  // namespace ttnn::operations::experimental::ccl

namespace ttnn::prim {

// Registers StridedAllGatherAsync under the fully qualified name
// "ttnn::prim::strided_all_gather_async" and exposes the resulting operation
// object as ttnn::prim::strided_all_gather_async.
// NOTE(review): the string literal is presumably expected to match this
// variable's qualified name - keep the two in sync when renaming.
constexpr auto strided_all_gather_async = ttnn::register_operation<
    "ttnn::prim::strided_all_gather_async",
    ttnn::operations::experimental::ccl::strided_all_gather_async::StridedAllGatherAsync>();

}  // namespace ttnn::prim
